from typing import Tuple

import numpy as np
import torch
import torch.utils.data as data
import random

from auxiliary.utils import hwc_chw, gamma_correct, brg_to_rgb
from classes.data.DataAugmenter import DataAugmenter


class BaseDataset(data.Dataset):
    """Dataset yielding randomly-centred 3-frame windows from sequence directories.

    Each item is identified by a sequence directory (``self._paths_to_seqs``)
    and its frame count (``self._nums_to_seqs``). Both lists start empty here
    and are presumably filled by subclasses. Frames are read from
    ``<dir>/<n>.dng.npy`` for ``n`` in ``1..count``. Per-frame illuminants are
    read from ``<dir>/illu_mat.npy``, whose ``.item()`` is indexed by the frame
    number as a string.
    """

    def __init__(self, mode, input_size, device):
        """
        :param mode: ``"train"`` enables random augmentation in ``__getitem__``;
            any other value only resizes the frames.
        :param input_size: forwarded to ``DataAugmenter``.
        :param device: stored on the instance; not used by this base class.
        """
        self.__device = device
        self.__input_size = input_size
        self.__da = DataAugmenter(self.__input_size)
        self._mode = mode
        self._paths_to_seqs = []
        self._nums_to_seqs = []

    @staticmethod
    def _neighbour_ids(centre: int, count: int) -> Tuple[int, int, int]:
        """Return the (previous, centre, next) frame ids, clamped to ``[1, count]``.

        At a sequence edge the missing neighbour is replaced by the centre frame
        itself, so a single-frame sequence yields ``(1, 1, 1)`` instead of
        referencing a non-existent frame 2.
        """
        return max(centre - 1, 1), centre, min(centre + 1, count)

    def __getitem__(self, index: int) -> Tuple:
        """Load a 3-frame window around a randomly chosen frame of sequence ``index``.

        A new centre frame is drawn on every call, so repeated access to the same
        index can return different windows.

        :return: ``(seq, mimic, illuminant, path_to_sequence)`` where
            ``seq`` is a tensor of the frames after clipping to [0, 255],
            scaling to [0, 1], ``brg_to_rgb``, ``gamma_correct`` and ``hwc_chw``;
            ``mimic`` is the ``augment_mimic`` output transposed with axes
            ``(0, 3, 1, 2)`` (and, in train mode, scaled per channel by the
            augmentation colour bias);
            ``illuminant`` is the float32 ground-truth illuminant of the centre frame;
            ``path_to_sequence`` is the sequence directory string.
        """
        path_to_sequence = self._paths_to_seqs[index]
        num_to_sequence = self._nums_to_seqs[index]

        # NOTE(review): allow_pickle=True unpickles file contents; only use trusted data.
        illums = np.load(path_to_sequence + '/illu_mat.npy', allow_pickle=True).item()

        frame_id = random.randint(1, num_to_sequence)
        files_seq = [path_to_sequence + '/' + str(i) + '.dng.npy'
                     for i in self._neighbour_ids(frame_id, num_to_sequence)]
        seq = np.array([np.array(np.load(f), dtype='float32') for f in files_seq], dtype='float32')
        illuminant = np.array(illums[str(frame_id)], dtype='float32')

        # The mimic is built from the un-augmented frames, before any train-time augmentation.
        mimic = torch.from_numpy(self.__da.augment_mimic(seq).transpose((0, 3, 1, 2)).copy())

        if self._mode == "train":
            seq, color_bias = self.__da.augment_sequence(seq, illuminant)
            # Keep only the diagonal of the returned colour-bias matrix:
            # one gain per channel.
            color_bias = np.array([[[color_bias[0][0], color_bias[1][1], color_bias[2][2]]]], dtype=np.float32)
            # Apply the same per-channel gains to the mimic so it matches the augmented frames.
            mimic = torch.mul(mimic, torch.from_numpy(color_bias).view(1, 3, 1, 1))
        else:
            seq = self.__da.resize_sequence(seq)

        seq = np.clip(seq, 0.0, 255.0) * (1.0 / 255)
        seq = hwc_chw(gamma_correct(brg_to_rgb(seq)))

        # .copy() gives torch a contiguous array that owns its memory.
        seq = torch.from_numpy(seq.copy())
        illuminant = torch.from_numpy(illuminant.copy())

        return seq, mimic, illuminant, path_to_sequence

    def __len__(self) -> int:
        """Return the number of sequences registered in ``self._paths_to_seqs``."""
        return len(self._paths_to_seqs)
